{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Neural style transfer\n",
    "\n",
    "**Author:** [fchollet](https://twitter.com/fchollet)<br>\n",
    "**Date created:** 2016/01/11<br>\n",
    "**Last modified:** 2020/05/02<br>\n",
    "**Description:** Transfering the style of a reference image to target image using gradient descent."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Introduction\n",
    "\n",
    "Style transfer consists in generating an image\n",
    "with the same \"content\" as a base image, but with the\n",
    "\"style\" of a different picture (typically artistic).\n",
    "This is achieved through the optimization of a loss function\n",
    "that has 3 components: \"style loss\", \"content loss\",\n",
    "and \"total variation loss\":\n",
    "\n",
    "- The total variation loss imposes local spatial continuity between\n",
    "the pixels of the combination image, giving it visual coherence.\n",
    "- The style loss is where the deep learning keeps in --that one is defined\n",
    "using a deep convolutional neural network. Precisely, it consists in a sum of\n",
    "L2 distances between the Gram matrices of the representations of\n",
    "the base image and the style reference image, extracted from\n",
    "different layers of a convnet (trained on ImageNet). The general idea\n",
    "is to capture color/texture information at different spatial\n",
    "scales (fairly large scales --defined by the depth of the layer considered).\n",
    "- The content loss is a L2 distance between the features of the base\n",
    "image (extracted from a deep layer) and the features of the combination image,\n",
    "keeping the generated image close enough to the original one.\n",
    "\n",
    "**Reference:** [A Neural Algorithm of Artistic Style](\n",
    "  http://arxiv.org/abs/1508.06576)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Setup\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras.applications import vgg19\n",
    "\n",
    "base_image_path = keras.utils.get_file(\"paris.jpg\", \"https://i.imgur.com/F28w3Ac.jpg\")\n",
    "style_reference_image_path = keras.utils.get_file(\n",
    "    \"starry_night.jpg\", \"https://i.imgur.com/9ooB60I.jpg\"\n",
    ")\n",
    "result_prefix = \"paris_generated\"\n",
    "\n",
    "# Weights of the different loss components\n",
    "total_variation_weight = 1e-6\n",
    "style_weight = 1e-6\n",
    "content_weight = 2.5e-8\n",
    "\n",
    "# Dimensions of the generated picture.\n",
    "width, height = keras.preprocessing.image.load_img(base_image_path).size\n",
    "img_nrows = 400\n",
    "img_ncols = int(width * img_nrows / height)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Let's take a look at our base (content) image and our style reference image\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from IPython.display import Image, display\n",
    "\n",
    "display(Image(base_image_path))\n",
    "display(Image(style_reference_image_path))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Image preprocessing / deprocessing utilities\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "def preprocess_image(image_path):\n",
    "    # Util function to open, resize and format pictures into appropriate tensors\n",
    "    img = keras.preprocessing.image.load_img(\n",
    "        image_path, target_size=(img_nrows, img_ncols)\n",
    "    )\n",
    "    img = keras.preprocessing.image.img_to_array(img)\n",
    "    img = np.expand_dims(img, axis=0)\n",
    "    img = vgg19.preprocess_input(img)\n",
    "    return tf.convert_to_tensor(img)\n",
    "\n",
    "\n",
    "def deprocess_image(x):\n",
    "    # Util function to convert a tensor into a valid image\n",
    "    x = x.reshape((img_nrows, img_ncols, 3))\n",
    "    # Remove zero-center by mean pixel\n",
    "    x[:, :, 0] += 103.939\n",
    "    x[:, :, 1] += 116.779\n",
    "    x[:, :, 2] += 123.68\n",
    "    # 'BGR'->'RGB'\n",
    "    x = x[:, :, ::-1]\n",
    "    x = np.clip(x, 0, 255).astype(\"uint8\")\n",
    "    return x\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Compute the style transfer loss\n",
    "\n",
    "First, we need to define 4 utility functions:\n",
    "\n",
    "- `gram_matrix` (used to compute the style loss)\n",
    "- The `style_loss` function, which keeps the generated image close to the local textures\n",
    "of the style reference image\n",
    "- The `content_loss` function, which keeps the high-level representation of the\n",
    "generated image close to that of the base image\n",
    "- The `total_variation_loss` function, a regularization loss which keeps the generated\n",
    "image locally-coherent\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# The gram matrix of an image tensor (feature-wise outer product)\n",
    "\n",
    "\n",
    "def gram_matrix(x):\n",
    "    x = tf.transpose(x, (2, 0, 1))\n",
    "    features = tf.reshape(x, (tf.shape(x)[0], -1))\n",
    "    gram = tf.matmul(features, tf.transpose(features))\n",
    "    return gram\n",
    "\n",
    "\n",
    "# The \"style loss\" is designed to maintain\n",
    "# the style of the reference image in the generated image.\n",
    "# It is based on the gram matrices (which capture style) of\n",
    "# feature maps from the style reference image\n",
    "# and from the generated image\n",
    "\n",
    "\n",
    "def style_loss(style, combination):\n",
    "    S = gram_matrix(style)\n",
    "    C = gram_matrix(combination)\n",
    "    channels = 3\n",
    "    size = img_nrows * img_ncols\n",
    "    return tf.reduce_sum(tf.square(S - C)) / (4.0 * (channels ** 2) * (size ** 2))\n",
    "\n",
    "\n",
    "# An auxiliary loss function\n",
    "# designed to maintain the \"content\" of the\n",
    "# base image in the generated image\n",
    "\n",
    "\n",
    "def content_loss(base, combination):\n",
    "    return tf.reduce_sum(tf.square(combination - base))\n",
    "\n",
    "\n",
    "# The 3rd loss function, total variation loss,\n",
    "# designed to keep the generated image locally coherent\n",
    "\n",
    "\n",
    "def total_variation_loss(x):\n",
    "    a = tf.square(\n",
    "        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, 1:, : img_ncols - 1, :]\n",
    "    )\n",
    "    b = tf.square(\n",
    "        x[:, : img_nrows - 1, : img_ncols - 1, :] - x[:, : img_nrows - 1, 1:, :]\n",
    "    )\n",
    "    return tf.reduce_sum(tf.pow(a + b, 1.25))\n",
    "\n"
   ]
  },
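  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a rough summary, the four utilities above can be written in equation form.\n",
    "The symbols are notation assumed here for illustration, not names taken from the code:\n",
    "$F \\in \\mathbb{R}^{C \\times HW}$ is a layer's feature map with its $C$ channels flattened into rows,\n",
    "$S$ and $G$ are the Gram matrices of the style reference and combination features,\n",
    "$x$ is the combination image, and $N$ is `img_nrows * img_ncols`.\n",
    "\n",
    "$$\\mathrm{gram}(F)_{ij} = \\sum_{k} F_{ik} F_{jk}$$\n",
    "\n",
    "$$L_{\\mathrm{style}} = \\frac{1}{4 \\cdot 3^2 \\cdot N^2} \\sum_{i,j} \\left(S_{ij} - G_{ij}\\right)^2$$\n",
    "\n",
    "$$L_{\\mathrm{content}} = \\sum_{i,k} \\left(F^{\\mathrm{comb}}_{ik} - F^{\\mathrm{base}}_{ik}\\right)^2$$\n",
    "\n",
    "$$L_{\\mathrm{tv}} = \\sum_{h,w,c} \\left[\\left(x_{h,w,c} - x_{h+1,w,c}\\right)^2 + \\left(x_{h,w,c} - x_{h,w+1,c}\\right)^2\\right]^{1.25}$$\n",
    "\n",
    "In the last sum, $h$ and $w$ range over all rows and columns except the final one.\n",
    "Note that, as written in the code, the style loss is normalized by the 3 color channels and the pixel count\n",
    "of the image rather than by the dimensions of each feature map.\n"
   ]
  },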
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Next, let's create a feature extraction model that retrieves the intermediate activations\n",
    "of VGG19 (as a dict, by name).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Build a VGG19 model loaded with pre-trained ImageNet weights\n",
    "model = vgg19.VGG19(weights=\"imagenet\", include_top=False)\n",
    "\n",
    "# Get the symbolic outputs of each \"key\" layer (we gave them unique names).\n",
    "outputs_dict = dict([(layer.name, layer.output) for layer in model.layers])\n",
    "\n",
    "# Set up a model that returns the activation values for every layer in\n",
    "# VGG19 (as a dict).\n",
    "feature_extractor = keras.Model(inputs=model.inputs, outputs=outputs_dict)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Finally, here's the code that computes the style transfer loss.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# List of layers to use for the style loss.\n",
    "style_layer_names = [\n",
    "    \"block1_conv1\",\n",
    "    \"block2_conv1\",\n",
    "    \"block3_conv1\",\n",
    "    \"block4_conv1\",\n",
    "    \"block5_conv1\",\n",
    "]\n",
    "# The layer to use for the content loss.\n",
    "content_layer_name = \"block5_conv2\"\n",
    "\n",
    "\n",
    "def compute_loss(combination_image, base_image, style_reference_image):\n",
    "    input_tensor = tf.concat(\n",
    "        [base_image, style_reference_image, combination_image], axis=0\n",
    "    )\n",
    "    features = feature_extractor(input_tensor)\n",
    "\n",
    "    # Initialize the loss\n",
    "    loss = tf.zeros(shape=())\n",
    "\n",
    "    # Add content loss\n",
    "    layer_features = features[content_layer_name]\n",
    "    base_image_features = layer_features[0, :, :, :]\n",
    "    combination_features = layer_features[2, :, :, :]\n",
    "    loss = loss + content_weight * content_loss(\n",
    "        base_image_features, combination_features\n",
    "    )\n",
    "    # Add style loss\n",
    "    for layer_name in style_layer_names:\n",
    "        layer_features = features[layer_name]\n",
    "        style_reference_features = layer_features[1, :, :, :]\n",
    "        combination_features = layer_features[2, :, :, :]\n",
    "        sl = style_loss(style_reference_features, combination_features)\n",
    "        loss += (style_weight / len(style_layer_names)) * sl\n",
    "\n",
    "    # Add total variation loss\n",
    "    loss += total_variation_weight * total_variation_loss(combination_image)\n",
    "    return loss\n",
    "\n"
   ]
  },
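  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "Put together, `compute_loss` evaluates a weighted sum that can be sketched as follows.\n",
    "The symbols are shorthand assumed here: $w_c$, $w_s$ and $w_{tv}$ stand for `content_weight`,\n",
    "`style_weight` and `total_variation_weight`, and $\\mathcal{S}$ is the set of 5 layers in `style_layer_names`.\n",
    "The content term is evaluated on the features of `content_layer_name`.\n",
    "\n",
    "$$L = w_c \\, L_{\\mathrm{content}} + \\frac{w_s}{|\\mathcal{S}|} \\sum_{l \\in \\mathcal{S}} L_{\\mathrm{style}}^{(l)} + w_{tv} \\, L_{\\mathrm{tv}}$$\n",
    "\n",
    "With the values from the setup cell, $w_c = 2.5 \\times 10^{-8}$, $w_s = 10^{-6}$ (so each style layer\n",
    "contributes with weight $2 \\times 10^{-7}$), and $w_{tv} = 10^{-6}$.\n"
   ]
  },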
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Add a tf.function decorator to loss & gradient computation\n",
    "\n",
    "To compile it, and thus make it fast.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "\n",
    "@tf.function\n",
    "def compute_loss_and_grads(combination_image, base_image, style_reference_image):\n",
    "    with tf.GradientTape() as tape:\n",
    "        loss = compute_loss(combination_image, base_image, style_reference_image)\n",
    "    grads = tape.gradient(loss, combination_image)\n",
    "    return loss, grads\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## The training loop\n",
    "\n",
    "Repeatedly run vanilla gradient descent steps to minimize the loss, and save the\n",
    "resulting image every 100 iterations.\n",
    "\n",
    "We decay the learning rate by 0.96 every 100 steps.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "optimizer = keras.optimizers.SGD(\n",
    "    keras.optimizers.schedules.ExponentialDecay(\n",
    "        initial_learning_rate=100.0, decay_steps=100, decay_rate=0.96\n",
    "    )\n",
    ")\n",
    "\n",
    "base_image = preprocess_image(base_image_path)\n",
    "style_reference_image = preprocess_image(style_reference_image_path)\n",
    "combination_image = tf.Variable(preprocess_image(base_image_path))\n",
    "\n",
    "iterations = 4000\n",
    "for i in range(1, iterations + 1):\n",
    "    loss, grads = compute_loss_and_grads(\n",
    "        combination_image, base_image, style_reference_image\n",
    "    )\n",
    "    optimizer.apply_gradients([(grads, combination_image)])\n",
    "    if i % 100 == 0:\n",
    "        print(\"Iteration %d: loss=%.2f\" % (i, loss))\n",
    "        img = deprocess_image(combination_image.numpy())\n",
    "        fname = result_prefix + \"_at_iteration_%d.png\" % i\n",
    "        keras.preprocessing.image.save_img(fname, img)\n"
   ]
  },
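  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The update performed by the loop above can be sketched as follows. The symbols are assumed here for\n",
    "illustration: $x_t$ is the combination image after $t$ updates and $\\eta_t$ is the learning rate.\n",
    "This assumes `ExponentialDecay` is used with its default continuous (non-staircase) behavior, as in the code:\n",
    "\n",
    "$$\\eta_t = 100 \\cdot 0.96^{\\,t / 100}, \\qquad x_{t+1} = x_t - \\eta_t \\, \\nabla_x L(x_t)$$\n",
    "\n",
    "For example, $\\eta_{100} = 96$, and by step 4000 the learning rate has dropped to\n",
    "$100 \\cdot 0.96^{40} \\approx 19.5$.\n"
   ]
  },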
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "After 4000 iterations, you get the following result:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "display(Image(result_prefix + \"_at_iteration_4000.png\"))\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "neural_style_transfer",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}